﻿#include <list>
#include <memory>
#include <mutex>
#include <unordered_map>
#include <unordered_set>

#include "bytearraystream.h"
#include "tcpsocket.h"
#include "datastream.h"
#include "messagechannel.h"

// Factories consulted by Message::UnSerilize to instantiate a message from
// its type tag; populated by Message::RegisterFactory and pruned by
// Message::UnRegisterFactory. Entries are non-owning pointers.
static std::unordered_set<MessageFactory*> g_factorys;
// Lock used by Register/UnRegisterFactory when mutating g_factorys.
static std::mutex g_mutex;

// Role of a Channel. It starts as None; Bind() switches it to Host and
// ConnectTo() to Client. Both calls return early once a role is set.
enum Mode{
    None,
    Host,
    Client
};

// Per-connection state kept by a Channel in Host mode.
struct ClientData{
    ClientId id{};                      // id handed out in slotOnNewConnection (value-initialized until then)
    std::shared_ptr<TcpSocket> socket;  // accepted connection; Channel holds a shared owner
    ByteArray buffer;                   // received bytes not yet decoded into a Message
};

// Private implementation of Channel. Channel's methods lock `mutex` before
// touching any of these fields.
struct Channel::Private{
    Private() = default;

    // `server` and `host` are owning raw pointers; a copy would double-delete.
    Private(const Private&) = delete;
    Private& operator=(const Private&) = delete;

    ~Private(){
        // Drop accepted client sockets first, then free the objects that
        // Channel::ConnectTo / Channel::Bind created with `new` (nothing else
        // in this file deletes them, so they used to leak).
        clients.clear();
        delete host;
        delete server;
    }

    TcpServer* server = nullptr;   // listening server, created in Bind() (Host mode)
    ClientId maxId = 1;            // id assigned to the next accepted client
    std::unordered_map<ClientId,ClientData> clients;

    TcpSocket* host = nullptr;     // connection to the host, created in ConnectTo() (Client mode)
    ByteArray buffer;              // undecoded bytes received from the host
    bool valid = false;

    std::list<MessagePtr> messageQueue;  // decoded messages waiting for nextMessage()

    Mode mode = None;
    std::recursive_mutex mutex;    // recursive: slots re-enter (e.g. slotOnNewConnection -> slotOnClientMessage)

    ConnectionRegistry reg;
};

// Creates an idle channel (mode None); call Bind() or ConnectTo() to start it.
Channel::Channel() : _P(new Private) {}

// Tears the channel down by delegating to close().
Channel::~Channel()
{
    this->close();
}

void Channel::Bind(int port)
{
    std::lock_guard<std::recursive_mutex> lock(_P->mutex);
    if(_P->mode != None){return;}
    _P->mode = Host;
    _P->server = new TcpServer();
    _P->server->Bind(port);
    _P->server->NewConnection.Bind([=](){
        slotOnNewConnection();
    },this);
}

void Channel::ConnectTo(std::string host,int port)
{
    std::lock_guard<std::recursive_mutex> lock(_P->mutex);
    if(_P->mode != None){return;}
    _P->mode = Client;
    _P->host = new TcpSocket();
    _P->host->Bind();
    _P->host->ConnectTo(host,port);
    _P->host->ReadyRead.Bind([=](){
        slotOnHostMessage();
    },this);
    _P->host->Disconnected.Bind([=](){
        slotOnHostDisconnected();
    }, this);
}

// Intended to shut the channel down. It is called from ~Channel and from
// slotOnClientMessage / slotOnHostMessage when Message::UnSerilize throws a
// MessageException.
// NOTE(review): currently a no-op. It does not release _P->server or
// _P->host, drop _P->clients, or reset _P->mode, so after a protocol error
// the channel keeps running. TODO implement. Care is needed because it is
// invoked from inside socket ReadyRead callbacks, so deleting the emitting
// socket here may be unsafe.
void Channel::close()
{
    // intentionally empty for now (see NOTE above)
}

// Serializes `message` once and sends the frame:
//  - Host mode: to every connected client;
//  - Client mode: to the host;
//  - None: nothing is sent.
// The caller keeps ownership of `message`.
void Channel::SendMessage(Message* message)
{
    std::lock_guard<std::recursive_mutex> lock(_P->mutex);
    if(_P->mode == None){return;}

    ByteArray frame = Message::Serilize(message);
    const unsigned char* bytes = frame.data();

    if(_P->mode == Host){
        // Iterate by reference: copying each entry would copy its shared_ptr
        // (atomic refcount) and its whole receive buffer on every broadcast.
        for(const auto& entry : _P->clients){
            entry.second.socket->write(bytes, frame.size());
        }
    }
    else if(_P->mode == Client){
        _P->host->write(bytes, frame.size());
    }
}

// Sends `message` to the single client `id`. Only meaningful in Host mode;
// it is silently dropped when the channel is not hosting or `id` is unknown.
void Channel::SendMessage(Message* message, ClientId id)
{
    std::lock_guard<std::recursive_mutex> guard(_P->mutex);
    if (_P->mode != Host) {
        return;
    }

    ByteArray frame = Message::Serilize(message);
    auto target = _P->clients.find(id);
    if (target == _P->clients.end()) {
        return;
    }
    target->second.socket->write(frame.data(), frame.size());
}

// Returns true when at least one decoded message is waiting in the queue.
bool Channel::hasMessage() const
{
    std::lock_guard<std::recursive_mutex> guard(_P->mutex);
    const bool queueEmpty = _P->messageQueue.empty();
    return !queueEmpty;
}

// Removes and returns the oldest queued message, or a null MessagePtr when
// the queue is empty.
MessagePtr Channel::nextMessage()
{
    std::lock_guard<std::recursive_mutex> guard(_P->mutex);

    if (_P->messageQueue.empty()) {
        return NULL;
    }

    MessagePtr head = _P->messageQueue.front();
    _P->messageQueue.pop_front();
    return head;
}

// Handles readable data on client `id` (Host mode). Appends the new bytes to
// that client's buffer, decodes every complete frame into the message queue
// (tagging each with `source = id`), and emits NewMessage() if anything was
// queued. A MessageException from decoding triggers close().
void Channel::slotOnClientMessage(ClientId id)
{
    std::lock_guard<std::recursive_mutex> lock(_P->mutex);

    auto found = _P->clients.find(id);
    if(found == _P->clients.end()){
        return;
    }
    ClientData& client = found->second;

    // Append after any partial frame left from a previous read. The old code
    // inserted at begin(), which put new bytes ahead of the leftover ones and
    // corrupted the stream.
    client.buffer += client.socket->readAll();

    bool queuedAny = false;

    while (true) {
        Message* m = NULL;
        try {
            m = Message::UnSerilize(client.buffer);
        } catch (const MessageException&) {
            close();
            return ;
        }
        if(!m){
            break;  // no complete frame left in the buffer
        }

        m->source = id;  // trust the connection, not the sender-supplied source
        queuedAny = true;
        _P->messageQueue.push_back(MessagePtr(m));
    }

    if(queuedAny){
        NewMessage();
    }
}

void Channel::slotOnNewConnection()
{
    std::lock_guard<std::recursive_mutex> lock(_P->mutex);
    ClientId id = _P->maxId;

    while (_P->server->hasPendingConnection()) {
        TcpSocket * socket = _P->server->pendingConnection();
        socket->ReadyRead.Bind([=](){slotOnClientMessage(id);},this);

        ClientData data;

        data.id = id;
        data.socket.reset(socket);

        _P->clients[id] = data;

		slotOnClientMessage(id);
    }
}

void Channel::slotOnHostMessage()
{
    std::lock_guard<std::recursive_mutex> lock(_P->mutex);


    _P->buffer += _P->host->readAll();

    bool flag = false;

    while (true) {
        Message* m = NULL;
        try {
            m = Message::UnSerilize(_P->buffer);
        } catch (const MessageException& e) {
            close();
            return ;
        }

        if(!m){
            break;
        }
        flag = true;
        _P->messageQueue.push_back(MessagePtr(m));
    }

    if(flag){
        NewMessage();
    }
}

// Invoked when the host socket emits Disconnected (bound in ConnectTo).
// NOTE(review): currently does nothing. `mode` stays Client and `_P->host`
// is kept, so callers get no signal that the link dropped. TODO confirm the
// intended behaviour.
void Channel::slotOnHostDisconnected()
{
    // intentionally empty for now (see NOTE above)
}

// Encodes `message` as one wire frame:
//   1. message->Serilize() writes the body (type, source, subclass fields);
//   2. the body is written through `s << content`, producing the size-prefixed
//      frame that Message::UnSerilize(ByteArray&) reads back.
// Returns the frame bytes. The caller keeps ownership of `message`.
ByteArray Message::Serilize(Message *message)
{
	ByteArray content;
	{
		BytearrayOutputStream btao;
		DataOutputStream s(&btao);

		message->Serilize(s);
		content = btao.GetByteArray();
	}

	ByteArray ret;
	{
		BytearrayOutputStream btao;
		DataOutputStream s(&btao);

		s << content;

		// Store the framed bytes in `ret`. The old code assigned them back to
		// `content`, so this function always returned an empty ByteArray.
		ret = btao.GetByteArray();
	}

	return ret;
}

// Tries to decode one complete frame from the front of `str`.
//
// A frame is an int size prefix followed by `size` payload bytes; the payload
// starts with the MessageType, which selects a registered factory.
// On success the consumed bytes are erased from `str` and a heap-allocated
// Message is returned; the caller takes ownership.
// A frame whose type no factory recognises is discarded, and decoding moves
// on to the next frame. Otherwise it would block every later message.
// Returns NULL when `str` holds no complete frame yet.
// Throws MessageException if the message's own UnSerilize throws; the
// partially built message is freed in that case.
Message *Message::UnSerilize(ByteArray & str) throw (MessageException)
{
    for (;;) {
        if (str.size() <= 4) {
            return NULL;  // not even a full size prefix plus payload yet
        }

        const unsigned char* buffer = str.data();
        const int len = static_cast<int>(str.size());

        BytearrayInputStream btai(buffer, len);
        DataInputStream s(&btai);

        int size;
        s >> size;
        const auto headerSize = s.byteReaded();
        if(size > s.avaliable()){
            return NULL;  // frame incomplete; wait for more bytes
        }

        MessageType type = L"";
        s >> type;

        // g_factorys is also mutated by Register/UnRegisterFactory under g_mutex.
        std::unique_ptr<Message> ret;
        {
            std::lock_guard<std::mutex> lock(g_mutex);
            for (MessageFactory* factory : g_factorys) {
                ret.reset(factory->CreateMessage(type));
                if (ret) {
                    break;
                }
            }
        }

        if (!ret) {
            if (size < 0) {
                return NULL;  // malformed size prefix: no safe way to skip it
            }
            // Unknown type: drop the whole frame (prefix + payload) and try the next.
            str.erase(str.begin(), str.begin() + headerSize + size);
            continue;
        }

        ret->type = type;
        ret->UnSerilize(s);  // if this throws, unique_ptr frees the message

        str.erase(str.begin(), str.begin() + s.byteReaded());
        return ret.release();
    }
}

// Writes the fields common to every message: its type, then its source.
// On decode, Message::UnSerilize(ByteArray&) reads the type itself, and
// UnSerilize(DataInputStream&) then reads `source`. Subclasses presumably
// call this first and append their own fields; verify against the overrides.
void Message::Serilize(DataOutputStream &s)
{
    s<<type;
    s<<source;
}

// Reads the base-class fields that follow the type tag (currently just
// `source`). The type itself has already been consumed by
// Message::UnSerilize(ByteArray&) before this is called.
void Message::UnSerilize(DataInputStream &s) throw (MessageException)
{
    s >> source;
}

// Adds `factory` to the set consulted when decoding incoming frames.
// The pointer is stored, not owned; presumably the caller must
// UnRegisterFactory it before destroying it. Registering twice has no extra
// effect (set semantics).
// A null factory is ignored, because the decoder calls CreateMessage on
// every registered entry.
void Message::RegisterFactory(MessageFactory *factory)
{
    if (!factory) {
        return;
    }
    std::lock_guard<std::mutex> lock(g_mutex);
    g_factorys.insert(factory);
}

// Removes `factory` from the decoder's factory set; unknown pointers are
// ignored. The factory object itself is not deleted.
void Message::UnRegisterFactory(MessageFactory *factory)
{
    const std::lock_guard<std::mutex> guard(g_mutex);
    g_factorys.erase(factory);
}
